{
 "cells": [
  {
   "cell_type": "markdown",
   "metadata": {
    "colab_type": "text",
    "id": "DfPPQ6ztJhv4"
   },
   "source": [
    "# Faster R-CNN with FPN training\n",
    "\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "%load_ext autoreload\n",
    "%autoreload 2"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "colab_type": "text",
    "id": "bX0rqK-A3Nbl"
   },
   "outputs": [],
   "source": [
    "import os\n",
    "import sys\n",
    "from tqdm.notebook import tqdm\n",
    "import torch\n",
    "\n",
    "sys.path.insert(0, os.path.abspath('../../src/faster_rcnn_fpn'))"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "colab_type": "text",
    "id": "gdIDvMw_Z4lY"
   },
   "source": [
    "Let's have a look at the dataset and how it is layed down.\n",
    "\n",
    "Here is one example of an image in the dataset, with its corresponding instance segmentation mask"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "colab": {},
    "colab_type": "code",
    "id": "LDjuVFgexFfh"
   },
   "outputs": [],
   "source": [
    "from PIL import Image\n",
    "import os.path as osp\n",
    "\n",
    "data_root_dir = '../../data/MOT17Det'\n",
    "output_dir = \"../../output/faster_rcnn_fpn/faster_rcnn_fpn_training_mot_17_split_09\"\n",
    "\n",
    "if not osp.exists(output_dir):\n",
    "    os.makedirs(output_dir)\n",
    "\n",
    "# Image.open(osp.join(data_root_dir, 'train/MOT17-02/img1/000001.jpg'))"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "colab_type": "text",
    "id": "5Sd4jlGp2eLm"
   },
   "source": [
    "## Defining the Dataset\n",
    "\n",
    "The [torchvision reference scripts for training object detection, instance segmentation and person keypoint detection](https://github.com/pytorch/vision/tree/v0.3.0/references/detection) allows for easily supporting adding new custom datasets.\n",
    "The dataset should inherit from the standard `torch.utils.data.Dataset` class, and implement `__len__` and `__getitem__`.\n",
    "\n",
    "The only specificity that we require is that the dataset `__getitem__` should return:\n",
    "\n",
    "* image: a PIL Image of size (H, W)\n",
    "* target: a dict containing the following fields\n",
    "    * `boxes` (`FloatTensor[N, 4]`): the coordinates of the `N` bounding boxes in `[x0, y0, x1, y1]` format, ranging from `0` to `W` and `0` to `H`\n",
    "    * `labels` (`Int64Tensor[N]`): the label for each bounding box\n",
    "    * `image_id` (`Int64Tensor[1]`): an image identifier. It should be unique between all the images in the dataset, and is used during evaluation\n",
    "    * `area` (`Tensor[N]`): The area of the bounding box. This is used during evaluation with the COCO metric, to separate the metric scores between small, medium and large boxes.\n",
    "    * `iscrowd` (`UInt8Tensor[N]`): instances with `iscrowd=True` will be ignored during evaluation.\n",
    "    * (optionally) `masks` (`UInt8Tensor[N, H, W]`): The segmentation masks for each one of the objects\n",
    "    * (optionally) `keypoints` (`FloatTensor[N, K, 3]`): For each one of the `N` objects, it contains the `K` keypoints in `[x, y, visibility]` format, defining the object. `visibility=0` means that the keypoint is not visible. Note that for data augmentation, the notion of flipping a keypoint is dependent on the data representation, and you should probably adapt `references/detection/transforms.py` for your new keypoint representation\n",
    "\n",
    "If your model returns the above methods, they will make it work for both training and evaluation, and will use the evaluation scripts from pycocotools.\n",
    "\n",
    "Additionally, if you want to use aspect ratio grouping during training (so that each batch only contains images with similar aspect ratio), then it is recommended to also implement a `get_height_and_width` method, which returns the height and the width of the image. If this method is not provided, we query all elements of the dataset via `__getitem__` , which loads the image in memory and is slower than if a custom method is provided.\n"
   ]
  },
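  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "As a minimal sketch of the contract described above, the toy dataset below returns a PIL image and a target dict with the required fields. `ToyBoxDataset` and its constructor arguments are hypothetical names used for illustration only; this is not how `MOTObjDetect` is implemented, and a blank image stands in for a real frame.\n",
    "\n",
    "```python\n",
    "import torch\n",
    "from PIL import Image\n",
    "\n",
    "\n",
    "class ToyBoxDataset(torch.utils.data.Dataset):\n",
    "    # hypothetical in-memory dataset, for illustration only\n",
    "\n",
    "    def __init__(self, sizes, boxes_per_image):\n",
    "        # sizes: list of (width, height) tuples\n",
    "        # boxes_per_image: one list of [x0, y0, x1, y1] boxes per image\n",
    "        self._sizes = sizes\n",
    "        self._boxes = boxes_per_image\n",
    "\n",
    "    def __len__(self):\n",
    "        return len(self._sizes)\n",
    "\n",
    "    def __getitem__(self, idx):\n",
    "        width, height = self._sizes[idx]\n",
    "        img = Image.new('RGB', (width, height))  # blank stand-in for a real frame\n",
    "        boxes = torch.as_tensor(self._boxes[idx], dtype=torch.float32).reshape(-1, 4)\n",
    "        num_objs = boxes.shape[0]\n",
    "        target = {\n",
    "            'boxes': boxes,\n",
    "            # assumption: a single foreground class with label 1\n",
    "            'labels': torch.ones((num_objs,), dtype=torch.int64),\n",
    "            'image_id': torch.tensor([idx]),\n",
    "            'area': (boxes[:, 2] - boxes[:, 0]) * (boxes[:, 3] - boxes[:, 1]),\n",
    "            'iscrowd': torch.zeros((num_objs,), dtype=torch.uint8),\n",
    "        }\n",
    "        return img, target\n",
    "\n",
    "    def get_height_and_width(self, idx):\n",
    "        # lets aspect ratio grouping avoid loading the image\n",
    "        width, height = self._sizes[idx]\n",
    "        return height, width\n",
    "\n",
    "\n",
    "toy = ToyBoxDataset(sizes=[(64, 48)],\n",
    "                    boxes_per_image=[[[10, 5, 30, 25], [0, 0, 64, 48]]])\n",
    "toy_img, toy_target = toy[0]\n",
    "```\n",
    "\n",
    "Here `toy[0]` yields one image that is 64 pixels wide and 48 pixels high together with two boxes, and `get_height_and_width(0)` reports the image size without building the image."
   ]
  },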
  {
   "cell_type": "markdown",
   "metadata": {
    "colab_type": "text",
    "id": "C9Ee5NV54Dmj"
   },
   "source": [
    "So each image has a corresponding segmentation mask, where each color correspond to a different instance. Let's write a `torch.utils.data.Dataset` class for this dataset."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "colab": {},
    "colab_type": "code",
    "id": "3R9PRiGaasyM"
   },
   "outputs": [],
   "source": [
    "from mot_data import MOTObjDetect"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "colab": {},
    "colab_type": "code",
    "id": "AZNTZGnUitoE"
   },
   "outputs": [],
   "source": [
    "import matplotlib.pyplot as plt\n",
    "import transforms as T\n",
    "\n",
    "def plot(img, boxes):\n",
    "  fig, ax = plt.subplots(1, dpi=96)\n",
    "\n",
    "  img = img.mul(255).permute(1, 2, 0).byte().numpy()\n",
    "  width, height, _ = img.shape\n",
    "    \n",
    "  ax.imshow(img, cmap='gray')\n",
    "  fig.set_size_inches(width / 80, height / 80)\n",
    "\n",
    "  for box in boxes:\n",
    "      rect = plt.Rectangle(\n",
    "        (box[0], box[1]),\n",
    "        box[2] - box[0],\n",
    "        box[3] - box[1],\n",
    "        fill=False,\n",
    "        linewidth=1.0)\n",
    "      ax.add_patch(rect)\n",
    "\n",
    "  plt.axis('off')\n",
    "  plt.show()\n",
    "\n",
    "dataset = MOTObjDetect(osp.join(data_root_dir, 'train'), split_seqs=['MOT17-09'])\n",
    "print(len(dataset))\n",
    "img, target = dataset[-5]\n",
    "img, target = T.ToTensor()(img, target)\n",
    "plot(img, target['boxes'])"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "colab_type": "text",
    "id": "J6f3ZOTJ4Km9"
   },
   "source": [
    "That's all for the dataset. Let's see how the outputs are structured for this dataset"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "colab": {},
    "colab_type": "code",
    "id": "YjNHjVMOyYlH"
   },
   "outputs": [],
   "source": [
    "import torchvision\n",
    "from torchvision.models.detection.faster_rcnn import FastRCNNPredictor\n",
    "\n",
    "      \n",
    "def get_detection_model(num_classes):\n",
    "    # load an instance segmentation model pre-trained on COCO\n",
    "    model = torchvision.models.detection.fasterrcnn_resnet50_fpn(pretrained=True)\n",
    "\n",
    "    # get the number of input features for the classifier\n",
    "    in_features = model.roi_heads.box_predictor.cls_score.in_features\n",
    "    # replace the pre-trained head with a new one\n",
    "    model.roi_heads.box_predictor = FastRCNNPredictor(in_features, num_classes)\n",
    "    model.roi_heads.nms_thresh = 0.3\n",
    "    \n",
    "    return model"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "colab_type": "text",
    "id": "3YFJGJxk6XEs"
   },
   "source": [
    "DATASETS"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "colab": {},
    "colab_type": "code",
    "id": "l79ivkwKy357"
   },
   "outputs": [],
   "source": [
    "from engine import train_one_epoch, evaluate\n",
    "import utils\n",
    "\n",
    "def get_transform(train):\n",
    "    transforms = []\n",
    "    # converts the image, a PIL image, into a PyTorch Tensor\n",
    "    transforms.append(T.ToTensor())\n",
    "    if train:\n",
    "        # during training, randomly flip the training images\n",
    "        # and ground-truth for data augmentation\n",
    "        transforms.append(T.RandomHorizontalFlip(0.5))\n",
    "    return T.Compose(transforms)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "colab": {},
    "colab_type": "code",
    "id": "a5dGaIezze3y"
   },
   "outputs": [],
   "source": [
    "train_split_seqs = test_split_seqs = None\n",
    "\n",
    "train_split_seqs = ['MOT17-02', 'MOT17-04', 'MOT17-05', 'MOT17-09', 'MOT17-10', 'MOT17-11', 'MOT17-13']\n",
    "test_split_seqs = ['MOT17-09']\n",
    "for seq in test_split_seqs:\n",
    "    train_split_seqs.remove(seq)\n",
    "\n",
    "# use our dataset and defined transformations\n",
    "dataset = MOTObjDetect(\n",
    "    osp.join(data_root_dir, 'train'),\n",
    "    get_transform(train=True),\n",
    "    split_seqs=train_split_seqs)\n",
    "dataset_no_random = MOTObjDetect(\n",
    "    osp.join(data_root_dir, 'train'),\n",
    "    get_transform(train=False),\n",
    "    split_seqs=train_split_seqs)\n",
    "# dataset_test = MOTObjDetect(\n",
    "#     osp.join(data_root_dir, 'test'),\n",
    "#     get_transform(train=False))\n",
    "dataset_test = MOTObjDetect(\n",
    "    osp.join(data_root_dir, 'train'),\n",
    "    get_transform(train=False),\n",
    "    split_seqs=test_split_seqs)\n",
    "\n",
    "# split the dataset in train and test set\n",
    "torch.manual_seed(1)\n",
    "# indices = torch.randperm(len(dataset)).tolist()\n",
    "# dataset = torch.utils.data.Subset(dataset, indices[:-50])\n",
    "# dataset_test = torch.utils.data.Subset(dataset_test, indices[-50:])\n",
    "\n",
    "# define training and validation data loaders\n",
    "data_loader = torch.utils.data.DataLoader(\n",
    "    dataset, batch_size=2, shuffle=True, num_workers=4,\n",
    "    collate_fn=utils.collate_fn)\n",
    "data_loader_no_random = torch.utils.data.DataLoader(\n",
    "    dataset_no_random, batch_size=2, shuffle=False, num_workers=4,\n",
    "    collate_fn=utils.collate_fn)\n",
    "\n",
    "data_loader_test = torch.utils.data.DataLoader(\n",
    "    dataset_test, batch_size=2, shuffle=False, num_workers=4,\n",
    "    collate_fn=utils.collate_fn)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "colab_type": "text",
    "id": "L5yvZUprj4ZN"
   },
   "source": [
    "INIT MODEL AND OPTIM"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "colab": {},
    "colab_type": "code",
    "id": "zoenkCj18C4h"
   },
   "outputs": [],
   "source": [
    "device = torch.device('cuda') if torch.cuda.is_available() else torch.device('cpu')\n",
    "\n",
    "# get the model using our helper function\n",
    "model = get_detection_model(dataset.num_classes)\n",
    "# move model to the right device\n",
    "model.to(device)\n",
    "\n",
    "# model_state_dict = torch.load(osp.join(output_dir, 'model_epoch_30.model'))\n",
    "# model.load_state_dict(model_state_dict)\n",
    "\n",
    "# construct an optimizer\n",
    "params = [p for p in model.parameters() if p.requires_grad]\n",
    "optimizer = torch.optim.SGD(params, lr=0.00001,\n",
    "                            momentum=0.9, weight_decay=0.0005)\n",
    "\n",
    "# and a learning rate scheduler which decreases the learning rate by\n",
    "# 10x every 3 epochs\n",
    "lr_scheduler = torch.optim.lr_scheduler.StepLR(optimizer,\n",
    "                                               step_size=10,\n",
    "                                               gamma=0.1)"
   ]
  },
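  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "To make the schedule concrete, here is a small sketch of the closed form that `StepLR` follows, assuming `lr_scheduler.step()` is called once at the end of every epoch and epochs are counted from 1. `step_lr` is a hypothetical helper written for illustration; it is not part of PyTorch.\n",
    "\n",
    "```python\n",
    "def step_lr(base_lr, gamma, step_size, epoch):\n",
    "    # learning rate in effect during a 1-indexed epoch, given that the\n",
    "    # scheduler has been stepped (epoch - 1) times before it starts\n",
    "    return base_lr * gamma ** ((epoch - 1) // step_size)\n",
    "\n",
    "\n",
    "# same values as the optimizer and scheduler defined above\n",
    "schedule = {epoch: step_lr(1e-5, 0.1, 10, epoch)\n",
    "            for epoch in (1, 10, 11, 20, 21, 30)}\n",
    "for epoch, lr in schedule.items():\n",
    "    print(f'epoch {epoch:2d}: lr = {lr:.0e}')\n",
    "```\n",
    "\n",
    "With `lr=0.00001`, `step_size=10` and `gamma=0.1`, epochs 1 to 10 train at 1e-5, epochs 11 to 20 at 1e-6, and epochs 21 to 30 at 1e-7."
   ]
  },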
  {
   "cell_type": "markdown",
   "metadata": {
    "colab_type": "text",
    "id": "XAd56lt4kDxc"
   },
   "source": [
    "TRAINING"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "colab": {
     "base_uri": "https://localhost:8080/",
     "height": 51
    },
    "colab_type": "code",
    "id": "18MtwEFcxLxD",
    "outputId": "c2f122ad-19fc-4258-ba2c-d256c4a61384"
   },
   "outputs": [],
   "source": [
    "def evaluate_and_write_result_files(model, data_loader):\n",
    "  print(f'EVAL {data_loader.dataset}')\n",
    "  model.eval()\n",
    "  results = {}\n",
    "  for imgs, targets in tqdm(data_loader):\n",
    "    imgs = [img.to(device) for img in imgs]\n",
    "\n",
    "    with torch.no_grad():\n",
    "        preds = model(imgs)\n",
    "    \n",
    "    for pred, target in zip(preds, targets):\n",
    "        results[target['image_id'].item()] = {'boxes': pred['boxes'].cpu(),\n",
    "                                              'scores': pred['scores'].cpu()}\n",
    "        \n",
    "  data_loader.dataset.write_results_files(results, output_dir)\n",
    "  data_loader.dataset.print_eval(results)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "colab": {},
    "colab_type": "code",
    "id": "at-h4OWK0aoc"
   },
   "outputs": [],
   "source": [
    "num_epochs = 30\n",
    "\n",
    "# evaluate_and_write_result_files(model, data_loader_no_random)\n",
    "evaluate_and_write_result_files(model, data_loader_test)\n",
    "\n",
    "for epoch in range(1, num_epochs + 1):\n",
    "    print(f'TRAIN {data_loader.dataset}')\n",
    "    train_one_epoch(model, optimizer, data_loader, device, epoch, print_freq=200)\n",
    "    \n",
    "    # update the learning rate\n",
    "    lr_scheduler.step()\n",
    "    \n",
    "    # evaluate on the test dataset\n",
    "    if epoch % 2 == 0:\n",
    "      evaluate_and_write_result_files(model, data_loader_test)\n",
    "      torch.save(model.state_dict(), osp.join(output_dir, f\"model_epoch_{epoch}.model\"))"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "colab_type": "text",
    "id": "Z6mYGFLxkO8F"
   },
   "source": [
    "QUALITATIVE TESTING"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "colab": {},
    "colab_type": "code",
    "id": "YHwIdxH76uPj"
   },
   "outputs": [],
   "source": [
    "# pick one image from the test set\n",
    "data_loader = torch.utils.data.DataLoader(\n",
    "    dataset_no_random, batch_size=1, shuffle=False, num_workers=4,\n",
    "    collate_fn=utils.collate_fn)\n",
    "\n",
    "for imgs, target in data_loader:\n",
    "    print(dataset._img_paths[0])\n",
    "\n",
    "    model.eval()\n",
    "    with torch.no_grad():\n",
    "        prediction = model([imgs[0].to(device)])[0]\n",
    "    \n",
    "    plot(imgs[0], prediction['boxes'])\n",
    "    plot(imgs[0], target[0]['boxes'])\n",
    "    break\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": []
  }
 ],
 "metadata": {
  "accelerator": "GPU",
  "colab": {
   "collapsed_sections": [],
   "name": "faster_rcnn_fpn_training_mot_17",
   "provenance": [],
   "toc_visible": true
  },
  "kernelspec": {
   "display_name": "Python 3",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.7.3"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 1
}